# AUTHOR: DING
# -*- coding: utf-8 -*-
# @Time: 2024/2/2 9:41
# @Author: 86139
# @Site: 
# @File: 05-read_data.py
# @Software: PyCharm

from torch.utils.data import Dataset
from PIL import Image
import os
import cv2 as cv


# print(help(Dataset))
class MyData(Dataset):
    """Image dataset built from one sub-directory of ``root_dir``.

    Every entry in ``<root_dir>/<label_dir>`` is one sample, and the
    sub-directory name itself is returned as the label of every sample,
    e.g. ``MyData('../hymenoptera_data/train', 'ants')``. Instances can be
    concatenated with ``+`` (inherited ``Dataset.__add__``) to build a
    multi-class dataset.
    """

    def __init__(self, root_dir, label_dir):
        """
        Args:
            root_dir: Directory containing the label sub-directories,
                e.g. ``'../hymenoptera_data/train'``.
            label_dir: Name of the sub-directory to load, e.g. ``'ants'``;
                also used as the label of every sample.
        """
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir() order is arbitrary and filesystem-dependent; sort it so
        # a given index always refers to the same file across machines/runs.
        self.imgs_path = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file of the directory.

        The image is decoded eagerly inside a ``with`` block so the file
        handle is released immediately. ``Image.open`` on its own is lazy and
        keeps the file open, which can exhaust file descriptors when a large
        dataset is iterated.
        """
        img_name = self.imgs_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        with open(img_item_path, 'rb') as f:
            img = Image.open(f)
            img.load()  # read pixel data now, before the file is closed
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries found in the label directory."""
        return len(self.imgs_path)


def main():
    """Build the ants + bees training dataset and display one sample.

    ``root_dir`` is a relative path, so it is resolved against the current
    working directory rather than this script's location.
    """
    root_dir = '../hymenoptera_data/train'
    ants_datalist = MyData(root_dir, 'ants')
    bees_datalist = MyData(root_dir, 'bees')
    # Dataset.__add__ concatenates: the ants samples come first, then bees.
    train_dataset = ants_datalist + bees_datalist
    # NOTE(review): index 124 presumably targets the first bee image (the
    # standard hymenoptera train split has 124 ants) — confirm locally.
    img, label = train_dataset[124]
    img.show()


if __name__ == '__main__':
    main()
